{% import 'macro.jinja2' as MRO with context %}
{% if train_mode == "performance" %}
{{ MRO.insert_tab()}}# performance space
{{ MRO.insert_tab()}}def search_space():
{{ MRO.insert_tab()}}    """Build and return the HyperSpace used when train_mode is "performance".
{{ MRO.insert_tab()}}
{{ MRO.insert_tab()}}    The net types are drawn from a MultipleChoice created with
{{ MRO.insert_tab()}}    num_chosen_least=3, DnnModule is created with no arguments, and
{{ MRO.insert_tab()}}    DTFit only varies the batch size.
{{ MRO.insert_tab()}}    """
{{ MRO.insert_tab()}}    hyper_space = HyperSpace()
{{ MRO.insert_tab()}}    with hyper_space.as_default():
{{ MRO.insert_tab()}}        # 'fgcnn_dnn_nets' and 'pnn_nets' are left out of the candidates.
{{ MRO.insert_tab()}}        net_candidates = [
{{ MRO.insert_tab()}}            'dnn_nets',
{{ MRO.insert_tab()}}            'linear',
{{ MRO.insert_tab()}}            'cin_nets',
{{ MRO.insert_tab()}}            'fm_nets',
{{ MRO.insert_tab()}}            'afm_nets',
{{ MRO.insert_tab()}}            'cross_nets',
{{ MRO.insert_tab()}}            'cross_dnn_nets',
{{ MRO.insert_tab()}}            'dcn_nets',
{{ MRO.insert_tab()}}            'autoint_nets',
{{ MRO.insert_tab()}}            'fibi_dnn_nets',
{{ MRO.insert_tab()}}        ]
{{ MRO.insert_tab()}}        # Keyword order is kept as-is so the search-space objects are
{{ MRO.insert_tab()}}        # created in the same sequence as before.
{{ MRO.insert_tab()}}        table_module = DTModuleSpace(
{{ MRO.insert_tab()}}            nets=MultipleChoice(net_candidates, num_chosen_least=3),
{{ MRO.insert_tab()}}            auto_categorize=Bool(),
{{ MRO.insert_tab()}}            cat_remain_numeric=Bool(),
{{ MRO.insert_tab()}}            auto_discrete=Bool(),
{{ MRO.insert_tab()}}            # Not searched at the moment:
{{ MRO.insert_tab()}}            # apply_gbm_features=Bool(),
{{ MRO.insert_tab()}}            # gbm_feature_type=Choice([DT_consts.GBM_FEATURE_TYPE_DENSE, DT_consts.GBM_FEATURE_TYPE_EMB]),
{{ MRO.insert_tab()}}            embeddings_output_dim=Choice([4, 10, 20]),
{{ MRO.insert_tab()}}            embedding_dropout=Choice([0, 0.1, 0.2, 0.3, 0.4, 0.5]),
{{ MRO.insert_tab()}}            stacking_op=Choice([DT_consts.STACKING_OP_ADD, DT_consts.STACKING_OP_CONCAT]),
{{ MRO.insert_tab()}}            output_use_bias=Bool(),
{{ MRO.insert_tab()}}            apply_class_weight=Bool(),
{{ MRO.insert_tab()}}            # The candidate values are injected by the template context.
{{ MRO.insert_tab()}}            earlystopping_patience=Choice({{ earlystopping_patience }}),
{{ MRO.insert_tab()}}        )
{{ MRO.insert_tab()}}        # Both modules are called with the table module as their input.
{{ MRO.insert_tab()}}        dnn_module = DnnModule()(table_module)
{{ MRO.insert_tab()}}        fit_module = DTFit(batch_size=Choice([64, 128]))(table_module)
{{ MRO.insert_tab()}}    return hyper_space
{% elif  train_mode == "quick" %}
{{ MRO.insert_tab()}}# quick space
{{ MRO.insert_tab()}}def search_space():
{{ MRO.insert_tab()}}    """Build and return the HyperSpace used when train_mode is "quick".
{{ MRO.insert_tab()}}
{{ MRO.insert_tab()}}    Compared with the performance space it offers three net types
{{ MRO.insert_tab()}}    (num_chosen_least=1), shorter value lists, a two-layer DnnModule and
{{ MRO.insert_tab()}}    a DTFit with epochs=500.
{{ MRO.insert_tab()}}    """
{{ MRO.insert_tab()}}    hyper_space = HyperSpace()
{{ MRO.insert_tab()}}    with hyper_space.as_default():
{{ MRO.insert_tab()}}        net_candidates = ['dnn_nets', 'linear', 'fm_nets']
{{ MRO.insert_tab()}}        # Keyword order is kept as-is so the search-space objects are
{{ MRO.insert_tab()}}        # created in the same sequence as before.
{{ MRO.insert_tab()}}        table_module = DTModuleSpace(
{{ MRO.insert_tab()}}            nets=MultipleChoice(net_candidates, num_chosen_least=1),
{{ MRO.insert_tab()}}            auto_categorize=Bool(),
{{ MRO.insert_tab()}}            cat_remain_numeric=Bool(),
{{ MRO.insert_tab()}}            auto_discrete=Bool(),
{{ MRO.insert_tab()}}            # Not searched at the moment:
{{ MRO.insert_tab()}}            # apply_gbm_features=Bool(),
{{ MRO.insert_tab()}}            # gbm_feature_type=Choice([DT_consts.GBM_FEATURE_TYPE_DENSE, DT_consts.GBM_FEATURE_TYPE_EMB]),
{{ MRO.insert_tab()}}            embeddings_output_dim=Choice([4, 10]),
{{ MRO.insert_tab()}}            embedding_dropout=Choice([0, 0.5]),
{{ MRO.insert_tab()}}            stacking_op=Choice([DT_consts.STACKING_OP_ADD, DT_consts.STACKING_OP_CONCAT]),
{{ MRO.insert_tab()}}            output_use_bias=Bool(),
{{ MRO.insert_tab()}}            apply_class_weight=Bool(),
{{ MRO.insert_tab()}}            # The candidate values are injected by the template context.
{{ MRO.insert_tab()}}            earlystopping_patience=Choice({{ earlystopping_patience }}),
{{ MRO.insert_tab()}}        )
{{ MRO.insert_tab()}}        # Both modules are called with the table module as their input.
{{ MRO.insert_tab()}}        dnn_module = DnnModule(
{{ MRO.insert_tab()}}            dnn_units=Choice([100, 200]),
{{ MRO.insert_tab()}}            reduce_factor=Choice([1, 0.8]),
{{ MRO.insert_tab()}}            dnn_dropout=Choice([0, 0.3]),
{{ MRO.insert_tab()}}            use_bn=Bool(),
{{ MRO.insert_tab()}}            dnn_layers=2,
{{ MRO.insert_tab()}}            activation='relu',
{{ MRO.insert_tab()}}        )(table_module)
{{ MRO.insert_tab()}}        fit_module = DTFit(
{{ MRO.insert_tab()}}            batch_size=Choice([32, 64]),
{{ MRO.insert_tab()}}            epochs=500,
{{ MRO.insert_tab()}}        )(table_module)
{{ MRO.insert_tab()}}
{{ MRO.insert_tab()}}    return hyper_space
{% else %}
{{ MRO.insert_tab()}}# minimal space
{{ MRO.insert_tab()}}def search_space():
{{ MRO.insert_tab()}}    """Build and return the HyperSpace used for any other train_mode.
{{ MRO.insert_tab()}}
{{ MRO.insert_tab()}}    Almost every option is a single-value Choice; only reduce_factor and
{{ MRO.insert_tab()}}    use_bn of the DnnModule still vary. DTFit is created with epochs=2.
{{ MRO.insert_tab()}}    """
{{ MRO.insert_tab()}}    hyper_space = HyperSpace()
{{ MRO.insert_tab()}}    with hyper_space.as_default():
{{ MRO.insert_tab()}}        # The only candidate for `nets` is the list ['linear'].
{{ MRO.insert_tab()}}        linear_only = Choice([['linear']])
{{ MRO.insert_tab()}}        table_module = DTModuleSpace(
{{ MRO.insert_tab()}}            nets=linear_only,
{{ MRO.insert_tab()}}            auto_categorize=Choice([False]),
{{ MRO.insert_tab()}}            cat_remain_numeric=Choice([False]),
{{ MRO.insert_tab()}}            auto_discrete=Choice([False]),
{{ MRO.insert_tab()}}            # Not searched at the moment:
{{ MRO.insert_tab()}}            # apply_gbm_features=Bool(),
{{ MRO.insert_tab()}}            # gbm_feature_type=Choice([DT_consts.GBM_FEATURE_TYPE_DENSE, DT_consts.GBM_FEATURE_TYPE_EMB]),
{{ MRO.insert_tab()}}            embeddings_output_dim=Choice([4]),
{{ MRO.insert_tab()}}            embedding_dropout=Choice([0]),
{{ MRO.insert_tab()}}            stacking_op=Choice([DT_consts.STACKING_OP_ADD]),
{{ MRO.insert_tab()}}            output_use_bias=Choice([False]),
{{ MRO.insert_tab()}}            apply_class_weight=Choice([False]),
{{ MRO.insert_tab()}}            earlystopping_patience=Choice([1]),
{{ MRO.insert_tab()}}        )
{{ MRO.insert_tab()}}        # Both modules are called with the table module as their input.
{{ MRO.insert_tab()}}        dnn_module = DnnModule(
{{ MRO.insert_tab()}}            dnn_units=Choice([10]),
{{ MRO.insert_tab()}}            reduce_factor=Choice([1, 0.8]),
{{ MRO.insert_tab()}}            dnn_dropout=Choice([0]),
{{ MRO.insert_tab()}}            use_bn=Bool(),
{{ MRO.insert_tab()}}            dnn_layers=2,
{{ MRO.insert_tab()}}            activation='relu',
{{ MRO.insert_tab()}}        )(table_module)
{{ MRO.insert_tab()}}        fit_module = DTFit(batch_size=Choice([64]), epochs=2)(table_module)
{{ MRO.insert_tab()}}    return hyper_space
{% endif %}
{{ MRO.insert_tab()}}
{{ MRO.insert_tab()}}# Searcher that samples from the search_space() selected by train_mode.
{{ MRO.insert_tab()}}rs = MCTSSearcher(
{{ MRO.insert_tab()}}    search_space,
{{ MRO.insert_tab()}}    max_node_space=10,
{{ MRO.insert_tab()}}    optimize_direction=optimize_direction,
{{ MRO.insert_tab()}})
{{ MRO.insert_tab()}}
{{ MRO.insert_tab()}}# TrainTrialCallback is only rendered when target_source_type is 'raw_python'.
{{ MRO.insert_tab()}}hk = HyperDT(
{{ MRO.insert_tab()}}    rs,
{{ MRO.insert_tab()}}    callbacks=[
{{ MRO.insert_tab()}}        {% if target_source_type == 'raw_python' %}TrainTrialCallback(server_portal, train_job_name, dataset_name), {% endif %}SummaryCallback(),
{{ MRO.insert_tab()}}        FileLoggingCallback(rs),
{{ MRO.insert_tab()}}        EarlyStoppingCallback(max_no_improvement_trials=5, mode=optimize_direction.value),
{{ MRO.insert_tab()}}    ],
{{ MRO.insert_tab()}}    reward_metric=reward_metric,
{{ MRO.insert_tab()}}    # presumably each dnn_units entry is (units, dropout, use_bn) -- TODO confirm
{{ MRO.insert_tab()}}    dnn_params={
{{ MRO.insert_tab()}}        'dnn_units': ((256, 0, False), (256, 0, False)),
{{ MRO.insert_tab()}}        'dnn_activation': 'relu',
{{ MRO.insert_tab()}}    },
{{ MRO.insert_tab()}})
{{ MRO.insert_tab()}}
{{ MRO.insert_tab()}}# NOTE(review): X_eval / y_eval are the training data themselves, so trials
{{ MRO.insert_tab()}}# are scored on data they were fitted on -- confirm this is intended.
{{ MRO.insert_tab()}}hk.search(
{{ MRO.insert_tab()}}    X_train,
{{ MRO.insert_tab()}}    y_train,
{{ MRO.insert_tab()}}    X_eval=X_train,
{{ MRO.insert_tab()}}    y_eval=y_train,
{{ MRO.insert_tab()}}    max_trials=max_trials,
{{ MRO.insert_tab()}})
{{ MRO.insert_tab()}}
{{ MRO.insert_tab()}}# Train the final estimator from the best trial's sampled space.
{{ MRO.insert_tab()}}best_trial = hk.get_best_trial()
{{ MRO.insert_tab()}}estimator = hk.final_train(best_trial.space_sample, X_train, y_train)
